# Copyright 2024 The OpenXLA Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Asserts all tags in XLA are documented.

`bazel query //xla/... --output=build` is read from stdin, and then we check
all tags are present in the `_TAGS_TO_DOCUMENTATION_MAP`. Ideally we would parse
using
https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/build.proto
but this is not possible due to XLA's old protobuf version. So we parse by hand
instead.
"""
import logging
import sys
from typing import Set

# Every tag that may appear on a target under //xla/..., mapped to a short
# description of what it means (or a link to upstream documentation).
#
# Only the keys are checked by this script: `main` raises if the query output
# contains a tag that is not a key here, and logs (at INFO level) any key that
# no target uses. The values exist purely for human readers.
_TAGS_TO_DOCUMENTATION_MAP = {
    # Tags that Bazel recognizes
    "local": "https://bazel.build/reference/be/common-definitions",
    "manual": "https://bazel.build/reference/be/common-definitions",
    "large": "Conventional tag for `test_suites` of large tests",
    "__PYTHON_RULES_MIGRATION_DO_NOT_USE_WILL_BREAK__": "Internal bazel tag",
    # Various disable tags (currently recognized by OpenXLA CI)
    "no_oss": "Test is disabled on OpenXLA CI.",
    "no_mac": "Disabled on MacOS.",
    "no_windows": "Disabled on Windows.",
    "no_mac_arm64": "Disabled on ARM MacOS.",
    "not_run:arm": (
        "Not ran on ARM. Currently OpenXLA CI doesn't make a distinction and"
        " doesn't build things tagged with this either."
    ),
    # Various disable tags (currently *unrecognized* by OpenXLA CI)
    "notap": "Internal tag which disables the test. Not used on OpenXLA CI.",
    "nosan": "Disabled under all sanitizers. Not used on OpenXLA CI.",
    "noasan": "Disabled under asan. Not used on OpenXLA CI.",
    "nomsan": "Disabled under msan. Not used on OpenXLA CI.",
    "notsan": "Disabled under tsan. Not used on OpenXLA CI",
    "nobuilder": "Not built internally.",
    "nozapfhahn": "Internal tag. Disables gathering coverage",
    "optonly": "Should only be tested with -c opt",
    "nodebug": "Should not be tested in debug builds.",
    "config-cuda-only": (
        "Meaningless in OSS as all GPU tests are built with `--config=cuda`"
    ),
    # GPU tags
    "requires-gpu": (
        "Test requires GPU to execute. Fallback if neither CUDA nor ROCm is"
        " specified."
    ),
    "requires-gpu-amd": "Test requires AMD GPU to execute",
    "requires-gpu-intel": "Test requires Intel GPU to execute",
    "requires-gpu-nvidia": "Test requires NVIDIA GPU to execute",
    "requires-gpu-nvidia:2": "Test needs 2 NVIDIA GPUs to run",
    "requires-gpu-sm60-only": "Requires exactly sm60.",
    "requires-gpu-sm70-only": "Requires exactly sm70.",
    "requires-gpu-sm80-only": "Requires exactly sm80.",
    "requires-gpu-sm90-only": "Requires exactly sm90.",
    "requires-gpu-sm100-only": "Requires exactly sm100.",
    "gpu": "Catch-all tag for targets that should be built/tested on GPU CI",
    "cpu": "Catch-all tag for targets that should be built/tested on CPU CI.",
    "cuda-only": "Targets that require the CUDA backend to be enabled.",
    "rocm-only": "Targets that require the ROCm backend to be enabled.",
    "oneapi-only": "Targets that require the oneAPI backend to be enabled.",
    "no-oneapi": "Targets that are not configured for the oneAPI backend.",
    # Below tags are generated by `xla_test`.
    "broken": "Test will be marked with other tags to disable in `xla_test`.",
    "xla_interpreter": "Uses interpreter backend.",
    "xla_cpu": "Uses CPU backend.",
    "xla_amdgpu_any": "Uses ROCm backend.",
    "xla_nvgpu_any": "Uses NVIDIA GPU backend.",
    "xla_intelgpu_any": "Uses Intel GPU backend.",
    # Below tags are emitted alongside `requires-gpu-x` tags, which is what the
    # CI actually follows. So we may not execute on an A100, and instead use an
    # L4. These tags are taken literally internally.
    "xla_p100": "Runs on a p100.",
    "xla_v100": "Runs on a v100.",
    "xla_a100": "Runs on an a100.",
    "xla_h100": "Runs on an h100.",
    "xla_b200": "Runs on a b200.",
    # Below tags are consumed by `xla_test`.
    "test_migrated_to_hlo_runner_pjrt": (
        "Adds the appropriate `xla/tests:pjrt_$BACKEND_client_registry` to the"
        " annotated `xla_test` target. Adding this tag does not synthesize"
        " additional targets."
    ),
    "multi_gpu": "Used by `xla_test` to signal that multiple GPUs are needed.",
    "multi_gpu_h100": (
        "Used by `xla_test` to signal that multiple H100s are needed."
    ),
}


def get_tags_from_line(line: str) -> Set[str]:
  """Extracts the tags declared on a single line of `--output=build` output.

  `bazel query --output=build` prints each attribute on its own line, e.g.
  `  tags = ["tag1", "-tag2"],`. Only lines whose stripped text starts with
  `tags = ` are considered; every other line yields an empty set.

  The list is located relative to the `tags = ` prefix of the stripped line
  rather than at fixed column offsets. This means any indentation, a missing
  trailing comma or newline, and CRLF line endings are all handled.

  Args:
    line: One line of `bazel query ... --output=build` output.

  Returns:
    The set of tag names on the line, with surrounding quotes and any leading
    `-` (used by `test_suite` to exclude tags) removed. Empty if the line does
    not declare tags, the list is empty, or the value is not a list literal.
  """
  prefix = "tags = "
  stripped = line.strip()
  if not stripped.startswith(prefix):
    return set()

  # `["tag1", "tag2"],` -> `["tag1", "tag2"]`
  value = stripped[len(prefix):].rstrip(",").strip()
  if not (value.startswith("[") and value.endswith("]")):
    # Don't invent bogus tag names from something we can't parse; surface it.
    logging.warning("Could not parse tags from line: %r", line)
    return set()

  tags = set()
  for raw_tag in value[1:-1].split(","):
    # Drop whitespace, the quotes around each element, and the leading `-`
    # that `test_suite` uses to exclude tests carrying that tag.
    tag = raw_tag.strip().strip('"').lstrip("-")
    if tag:
      tags.add(tag)
  return tags


def main():
  """Checks that every tag in the query output on stdin is documented.

  Reads `bazel query //xla/... --output=build` output from stdin, collects the
  tags declared on every line, and compares them against
  `_TAGS_TO_DOCUMENTATION_MAP`. Documented-but-unused tags are only logged.

  Raises:
    ValueError: If stdin is empty, or if any tag found in the input is missing
      from `_TAGS_TO_DOCUMENTATION_MAP`.
  """
  logging.basicConfig()
  logging.getLogger().setLevel(logging.INFO)

  tags: Set[str] = set()
  saw_input = False
  for line in sys.stdin:
    saw_input = True
    tags |= get_tags_from_line(line)

  # An empty stdin almost certainly means the upstream `bazel query` failed;
  # fail loudly with a clear message instead of a confusing error or a
  # vacuous pass.
  if not saw_input:
    raise ValueError(
        "No input on stdin. Expected the output of"
        " `bazel query //xla/... --output=build`."
    )

  # Sort so the log and error output are deterministic and easy to scan.
  logging.info("Found tags: %s", sorted(tags))

  if undocumented_tags := tags - _TAGS_TO_DOCUMENTATION_MAP.keys():
    raise ValueError(
        f"Tag(s) {sorted(undocumented_tags)} are undocumented! Please document"
        " them in `build_tools/lint/tags.py`."
    )

  if unused_but_documented_tags := _TAGS_TO_DOCUMENTATION_MAP.keys() - tags:
    logging.info(
        "The following tags are documented but unused: %s. Do we expect they'll"
        " be used in the future?",
        sorted(unused_but_documented_tags),
    )


if __name__ == "__main__":
  main()
